from sqlalchemy import func

from app.models import db
from app.models.train_station_lib import TrainStation

class Train(db.Model):
    """ORM model for a train number and its schedule summary.

    ``departure_station`` / ``departure_time`` and ``arrival_station`` /
    ``arrival_time`` describe the first and last stops; ``buildTrain`` fills
    them in from the station list. ``effective_time`` / ``expiration_time``
    are treated by ``queryTrains`` as the inclusive window in which the train
    runs.
    """
    __tablename__: str = 'train'

    id = db.Column(db.Integer, primary_key=True)
    train_no = db.Column(db.String(120), unique=True, nullable=False, index=True)
    departure_station = db.Column(db.String(120))
    arrival_station = db.Column(db.String(120))
    departure_time = db.Column(db.DateTime)
    expiration_time = db.Column(db.DateTime)
    effective_time = db.Column(db.DateTime)
    arrival_time = db.Column(db.DateTime)
    created_at = db.Column(db.DateTime, default=func.now())
    # onupdate keeps the timestamp current on every UPDATE issued by the ORM;
    # with only `default` it would stay frozen at insert time.
    updated_at = db.Column(db.DateTime, default=func.now(), onupdate=func.now())

    def __repr__(self):
        return f'<Train {self.id}>'

    @classmethod
    def create(cls, new_train):
        """Add ``new_train`` to the session, commit, and return it."""
        db.session.add(new_train)
        db.session.commit()
        return new_train

    @classmethod
    def queryTrains(cls, from_station, to_station, date):
        """Return trains that stop at ``from_station`` before ``to_station`` on ``date``.

        A train qualifies when:
          * it has a ``TrainStation`` row for both station names,
          * the ``from_station`` row has a smaller ``index`` than the
            ``to_station`` row (i.e. the trip runs in the right direction), and
          * ``effective_time <= date <= expiration_time``.

        :param from_station: station name where the trip starts.
        :param to_station: station name where the trip ends.
        :param date: travel date/datetime, compared against the DateTime
            columns ``effective_time`` and ``expiration_time``.
        :return: list of matching ``Train`` instances (empty if none).
        """
        from_indexes = cls._first_stop_index(from_station)
        to_indexes = cls._first_stop_index(to_station)

        valid_train_nos = [
            train_no
            for train_no, from_idx in from_indexes.items()
            if train_no in to_indexes and from_idx < to_indexes[train_no]
        ]
        if not valid_train_nos:
            # Nothing can match; skip the round-trip and the empty IN () clause.
            return []

        # NOTE(review): if `date` is a bare date while the columns hold a
        # time-of-day, the comparison is against midnight — confirm callers
        # pass values with matching granularity.
        return cls.query.filter(
            cls.train_no.in_(valid_train_nos),
            cls.effective_time <= date,
            cls.expiration_time >= date,
        ).all()

    @staticmethod
    def _first_stop_index(station_name):
        """Map ``train_no`` -> ``index`` for TrainStation rows named ``station_name``.

        If a train has several rows for the same station, the first row the
        query returns is kept.
        """
        indexes = {}
        for ts in TrainStation.query.filter_by(station_name=station_name).all():
            indexes.setdefault(ts.train_no, ts.index)
        return indexes


def buildTrain(params):
    """Build an unsaved ``Train`` with its ``TrainStation`` children from ``params``.

    :param params: mapping with keys ``trainNo``, ``effective_time``,
        ``expiration_time`` and ``stations``. Each entry of ``stations`` is a
        mapping with ``name``, ``price``, ``depTime``, ``arrTime`` and
        ``index``.
    :return: the new ``Train``. It is not added to the session or committed
        here; the caller must persist it (e.g. with ``Train.create``).
    :raises KeyError: if a required key is missing from ``params`` or from a
        station entry.
    """
    train = Train(
        train_no=params['trainNo'],
        effective_time=params['effective_time'],
        expiration_time=params['expiration_time'],
    )
    stations = params['stations']

    # The stop with the smallest index is the origin and the one with the
    # largest index is the terminus. Compute both once up front; `default`
    # keeps an empty station list from raising.
    indexes = [stop["index"] for stop in stations]
    first_index = min(indexes, default=None)
    last_index = max(indexes, default=None)

    for stop in stations:
        train.train_stations.append(
            TrainStation(
                station_name=stop["name"],
                price=stop["price"],
                departure_time=stop["depTime"],
                arrival_time=stop["arrTime"],
                index=stop["index"],
            )
        )

        # Not mutually exclusive: a single-stop route is both origin and terminus.
        if stop["index"] == first_index:
            train.departure_time = stop["depTime"]
            train.departure_station = stop["name"]

        if stop["index"] == last_index:
            train.arrival_time = stop["arrTime"]
            train.arrival_station = stop["name"]

    return train
